﻿using System.Collections.Concurrent;

namespace PmSoft.Core.Extensions;

/// <summary>
/// Enumerable 类型的扩展方法
/// </summary>
public static class EnumerableExtensions
{
	/// <summary>
	/// 对集合中的每个元素执行指定的操作
	/// </summary>
	/// <typeparam name="T">集合中元素的类型</typeparam>
	/// <param name="source">集合</param>
	/// <param name="action">要对每个元素执行的操作</param>
	/// <returns>原始集合</returns>
	public static IEnumerable<T> ForEach<T>(this IEnumerable<T> source, Action<T> action)
	{
		foreach (T item in source)
		{
			action(item);
		}
		return source;
	}

	/// <summary>
	/// 异步对集合中的每个元素执行指定的操作
	/// </summary>
	/// <typeparam name="T">集合中元素的类型</typeparam>
	/// <param name="source">集合</param>
	/// <param name="action">要对每个元素执行的异步操作</param>
	/// <returns>表示异步操作的任务</returns>
	public static Task ForEachAsync<T>(this IEnumerable<T> source, Func<T, Task> action)
	{
		int partitionCount = Environment.ProcessorCount;

		return Task.WhenAll(
			from partition in Partitioner.Create(source).GetPartitions(partitionCount)
			select Task.Run(async delegate
			{
				using (partition)
				{
					while (partition.MoveNext())
					{
						await action(partition.Current);
					}
				}
			}));
	}

	/// <summary>
	/// 异步对集合中的每个元素执行指定的操作，并返回结果集合
	/// </summary>
	/// <typeparam name="T">集合中元素的类型</typeparam>
	/// <typeparam name="TResult">结果集合中元素的类型</typeparam>
	/// <param name="source">集合</param>
	/// <param name="action">要对每个元素执行的异步操作，返回结果</param>
	/// <returns>表示异步操作的任务，任务结果包含结果集合</returns>
	public static async Task<IEnumerable<TResult>> ForEachAsync<T, TResult>(this IEnumerable<T> source, Func<T, Task<TResult>> action)
	{
		int partitionCount = Environment.ProcessorCount;

		var results = new ConcurrentBag<TResult>();

		await Task.WhenAll(
			from partition in Partitioner.Create(source).GetPartitions(partitionCount)
			select Task.Run(async delegate
			{
				using (partition)
				{
					while (partition.MoveNext())
					{
						var result = await action(partition.Current);
						results.Add(result);
					}
				}
			}));

		return results;
	}

	/// <summary>
	/// 获取两个集合的交集（根据多个属性）
	/// </summary>
	public static IEnumerable<T> IntersectByKeys<T, TKey>(
		this IEnumerable<T> first, IEnumerable<T> second, Func<T, TKey> keySelector)
	{
		return first.Join(second, keySelector, keySelector, (f, s) => f);
	}

	/// <summary>
	/// 获取两个集合的差集（first 相对于 second 的差集，根据多个属性）
	/// </summary>
	public static IEnumerable<T> ExceptByKeys<T, TKey>(
		this IEnumerable<T> first, IEnumerable<T> second, Func<T, TKey> keySelector)
	{
		var secondKeys = second.Select(keySelector).ToHashSet();
		return first.Where(f => !secondKeys.Contains(keySelector(f)));
	}

	/// <summary>
	/// 获取两个集合的并集（根据多个属性去重）
	/// </summary>
	public static IEnumerable<T> UnionByKeys<T, TKey>(
		this IEnumerable<T> first, IEnumerable<T> second, Func<T, TKey> keySelector)
	{
		return first.Concat(second).GroupBy(keySelector).Select(g => g.First());
	}
}
